!
!  OAK, Ocean Assimilation Kit
!  Copyright(c) 2002-2011 Alexander Barth and Luc Vandenblucke
!
!  This program is free software; you can redistribute it and/or
!  modify it under the terms of the GNU General Public License
!  as published by the Free Software Foundation; either version 2
!  of the License, or (at your option) any later version.
!
!  This program is distributed in the hope that it will be useful,
!  but WITHOUT ANY WARRANTY; without even the implied warranty of
!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!  GNU General Public License for more details.
!
!  You should have received a copy of the GNU General Public License
!  along with this program; if not, write to the Free Software
!  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
!

! include the fortran preprocessor definitions
#include "ppdef.h"

module matoper

! Sparse matrix stored as a list of (row, column, value) triplets.
type SparseMatrix
  ! nz: number of stored entries; m, n: number of rows and of columns
  integer          :: nz,m,n
  ! i(k), j(k): row and column index of the k-th stored entry
  integer, pointer :: i(:),j(:)
  ! s(k): value of the k-th stored entry
  real, pointer    :: s(:)
end type


! Operator .x. : the product A B
!
! Generic over the specific procedures listed below. Only
! ssparsemat_mult_ssparsemat is written out in this file; the other names
! are the values given to the *_TYPE macros before each inclusion of
! matoper_inc.F90 (once for single and once for double precision).
! Judging by the names, "mat", "vec" and "ssparsemat" stand for a dense
! matrix, a vector and a SparseMatrix operand.

interface operator(.x.)
  module procedure              &
! single precision
          smat_mult_smat,       &
    ssparsemat_mult_smat,       &
          smat_mult_ssparsemat, &
          smat_mult_svec,       &
    ssparsemat_mult_svec,       &
    ssparsemat_mult_ssparsemat, &
          svec_mult_smat,       &
          svec_mult_ssparsemat, &
          svec_mult_svec,       &
! double precision
          dmat_mult_dmat,       &
    ssparsemat_mult_dmat,       &
          dmat_mult_ssparsemat, &
          dmat_mult_dvec,       &
    ssparsemat_mult_dvec,       &
          dvec_mult_dmat,       &
          dvec_mult_dvec
end interface




! Operator .xt. : the product of A with the transpose of B
!
! The specific procedures are not written out in this file; their names
! are the values of the corresponding *_TYPE macros that are set before
! each inclusion of matoper_inc.F90.

interface operator(.xt.)
  module procedure &
! single precision
    smat_mult_smatT, &
    smat_mult_ssparsematT, &
    svec_mult_svecT, &
! double precision
    dmat_mult_dmatT, &
    dmat_mult_ssparsematT, &
    dvec_mult_dvecT
end interface

! Operator .tx. : the product of the transpose of A with B
!
! The specific procedures are not written out in this file; their names
! are the values of the corresponding *_TYPE macros that are set before
! each inclusion of matoper_inc.F90.

interface operator(.tx.)
  module procedure &
! single precision
    smatT_mult_smat, &
    smatT_mult_svec, &
! double precision
    dmatT_mult_dmat, &
    dmatT_mult_dvec
end interface

! Operator .dx. : the product diag(A) B
!
! NOTE(review): A is presumably the vector of diagonal elements; the
! specific procedures come from matoper_inc.F90 and are not visible here.

interface operator(.dx.)
  module procedure & 
! single precision
    sdiag_mult_smat, &
    sdiag_mult_svec, &
! double precision
    ddiag_mult_dmat, &
    ddiag_mult_dvec
end interface

! Operator .xd. : the product A diag(B)
!
! NOTE(review): B is presumably the vector of diagonal elements; the
! specific procedures come from matoper_inc.F90 and are not visible here.

interface operator(.xd.)
  module procedure  &
! single precision
     smat_mult_sdiag, &
! double precision
     dmat_mult_ddiag
end interface


! Generic name inv for the single (sinv) and double (dinv) precision
! variants. Both names are macro values used when matoper_inc.F90 is
! included; presumably the matrix inverse - confirm in that file.
interface inv
  module procedure &
    sinv, &
    dinv
end interface

! Generic name det for the single (sdet) and double (ddet) precision
! variants. Both names are macro values used when matoper_inc.F90 is
! included; presumably the determinant - confirm in that file.
interface det
  module procedure &
    sdet, &
    ddet
end interface

! gesvd: obsolete interface to the singular value decomposition; new code
! should use the generic svd instead.
! The specific procedures are generated from matoper_inc.F90.
interface gesvd
  module procedure &
    gesvd_sf90, &
    gesvd_df90
end interface
 
! svd: singular value decomposition, the replacement for the obsolete
! gesvd generic. The single (svd_sf90) and double (svd_df90) precision
! variants are generated from matoper_inc.F90.
interface svd
  module procedure &
    svd_sf90, &
    svd_df90
end interface
  
! ssvd: generic over a single and a double precision variant, both
! generated from matoper_inc.F90.
! NOTE(review): how ssvd differs from svd cannot be told from this file.
interface ssvd
  module procedure &
    ssvd_singleprec, &
    ssvd_doubleprec
end interface
  

! symeig: generic over a single and a double precision variant, both
! generated from matoper_inc.F90. Presumably the eigenvalues and
! eigenvectors of a symmetric matrix, as described for the commented-out
! legacy copy of symeig_sf90 at the end of this module.
interface symeig
  module procedure &
    symeig_sf90, &
    symeig_df90
end interface


! eye: identity matrix. Only the double precision function deye is
! reachable through the generic name.
! NOTE(review): seye is presumably left out because seye(n) and deye(n)
! take the same argument list and differ only in the kind of the result,
! which generic resolution cannot distinguish. Call seye by its own name.
interface eye
  module procedure &
!    seye, &
    deye
end interface

! trace: generic over the single (strace) and double (dtrace) precision
! variants, both generated from matoper_inc.F90; presumably the sum of
! the diagonal elements - confirm in that file.
interface trace
  module procedure &
    strace, &
    dtrace
end interface

! diag: generic over real single (sdiag), real double (ddiag) and complex
! (cdiag) variants. cdiag is defined in this file and builds a diagonal
! matrix from a vector; sdiag and ddiag are generated from
! matoper_inc.F90.
interface diag
  module procedure &
    sdiag, &
    ddiag, &
    cdiag    
end interface

contains

!----------------------------------------------------------
!
! default precision
!
!----------------------------------------------------------


!_______________________________________________________
!
! create the identity matrix
!

 function seye(n) result(E)
  ! Single precision identity matrix.
  !   n : order of the matrix
  !   E : n-by-n result, 1 on the main diagonal and 0 elsewhere
  implicit none
  integer, intent(in) :: n
  real(4) :: E(n,n)
  integer :: row, col

  ! the inner loop runs over the first index (column-major storage)
  do col=1,n
    do row=1,n
      E(row,col) = merge(1., 0., row.eq.col)
    end do
  end do

 end function 

!_______________________________________________________
!
! create a matrix filled with 1
!

 function ones(n,m) result(E)
  ! Matrix with every element equal to one.
  !   n, m : number of rows and of columns
  !   E    : n-by-m result
  implicit none
  integer, intent(in) :: n,m
  real :: E(n,m)

  E(:,:) = 1.0
 end function 

!_______________________________________________________
!
! create a matrix filled with 0
!

 function zeros(n,m) result(E)
  ! Matrix with every element equal to zero.
  !   n, m : number of rows and of columns
  !   E    : n-by-m result
  implicit none
  integer, intent(in) :: n,m
  real :: E(n,m)

  E(:,:) = 0.0
 end function 


!_______________________________________________________
!
! create a matrix filled with uniformly distributed 
! random number between 0 and 1
!

 function rand(n,m) result(E)
  ! Matrix of pseudo-random numbers drawn with the intrinsic random_number,
  ! i.e. uniformly distributed with 0 <= E(i,j) < 1.
  !   n, m : number of rows and of columns
  !   E    : n-by-m result
  implicit none
  integer, intent(in) :: n,m
  real :: E(n,m)

  call random_number(harvest=E)
 end function 

!_______________________________________________________
!
! create a matrix filled with gaussian distributed 
! random number with 0 mean and 1 standard deviation
!
! FIXME: implement the "Box-Muller transformation"
!

 function randn(n,m) result(E)
  ! Matrix of approximately standard normal random numbers.
  !
  ! Each element is the sum of 12 uniform numbers from rand, shifted by -6.
  ! Such a sum has mean 0 and variance 1, and every value lies in [-6,6].
  !   n, m : number of rows and of columns
  !   E    : n-by-m result
  implicit none
  integer, intent(in) :: n,m
  real :: E(n,m)

  integer :: draw

  ! first of the 12 uniform draws
  E = rand(n,m)

  ! draws 2 to 12
  do draw=2,12
    E = E + rand(n,m)
  end do

  ! centre the sum on zero
  E = E-6
 end function 

!_______________________________________________________
!

function randUnitVector(n) result(w)
! Random vector of unit Euclidean length.
!
! The components are drawn with randn and the vector is then divided by
! its 2-norm.
!   n : number of components
!   w : result with sum(w*w) equal to 1 up to rounding
implicit none
integer, intent(in) :: n
real :: w(n)

real :: column(n,1)

column = randn(n,1)
w = column(:,1)
w = w/sqrt(sum(w*w))

end function

!_______________________________________________________
!
! generates n-1 orthonormal vectors perpendicular to w
!
! Hoteit et al, 2002
! http://dx.doi.org/10.1016/S0924-7963(02)00129-X

function perpSpace(w) result(H)
! Columns of H are size(w)-1 vectors perpendicular to w, built from the
! closed-form expression of Hoteit et al. (2002). The columns are
! orthonormal provided that w has unit length.
!   w : input vector of length n
!   H : result, n rows and n-1 columns
implicit none
real, intent(in) :: w(:)
real :: H(size(w),size(w)-1)

real :: alpha
integer :: n,col

n = size(w)
alpha = - 1/(abs(w(n))+1)

do col=1,n-1
  ! rows 1..n-1: identity plus a rank-one correction
  H(1:n-1,col) = alpha * w(1:n-1)*w(col)
  H(col,col) = H(col,col)+1

  ! last row
  H(n,col) = alpha * (w(n)+sign(1.,w(n)))*w(col)
end do

end function

!_______________________________________________________
!


function randOrthMatrix(n) result(Omega)
! Random n-by-n matrix built recursively: the leading k-by-k block is
! obtained from the leading (k-1)-by-(k-1) block by multiplying it with
! perpSpace(u) and appending the random unit vector u as k-th column.
!   n     : order of the matrix
!   Omega : n-by-n result
implicit none
integer, intent(in) :: n
real :: Omega(n,n)

integer :: k
real :: u(n)

Omega = 0

! 1-by-1 starting block
Omega(1:1,1) = randUnitVector(1)

do k=2,n
  u(1:k) = randUnitVector(k)

  ! the right-hand side is evaluated completely before it is stored
  Omega(1:k,1:k-1) = (perpSpace(u(1:k))).x.(Omega(1:k-1,1:k-1))
  Omega(1:k,k) = u(1:k)
end do

end function

!_______________________________________________________
!
! create a diagonal complex matrix
!

 function cdiag(d) result(A)
  ! Complex diagonal matrix.
  !   d : diagonal elements
  !   A : size(d)-by-size(d) result with A(k,k) = d(k) and 0 elsewhere
  implicit none
  complex, intent(in) :: d(:)
  complex :: A(size(d),size(d))
  integer :: k, n

  n = size(d)
  A = (0.,0.)

  do k=1,n
    A(k,k) = d(k)
  end do

 end function cdiag

!_______________________________________________________
!
! S = SPARSE(X) converts a full matrix to sparse form by
!    squeezing out any zero elements.

function sparse(X) result(S)
! Convert a dense matrix to triplet form, keeping only non-zero elements.
!   X : dense matrix
!   S : sparse result; S%m and S%n are the shape of X, S%nz the number of
!       non-zero elements. Entries are stored column by column.
implicit none
real, intent(in)   :: X(:,:)
type(SparseMatrix) :: S

integer :: i,j,k

S%m = size(X,1)
S%n = size(X,2)
S%nz = count(X.ne.0)

allocate(S%i(S%nz),S%j(S%nz),S%s(S%nz))
k = 0

 ! i is the row index (1..S%m) and j the column index (1..S%n); using the
 ! bounds the other way round indexes X outside its shape as soon as the
 ! matrix is not square.
 do j=1,S%n
   do i=1,S%m
     if (X(i,j).ne.0) then
       k=k+1
       S%i(k) = i
       S%j(k) = j
       S%s(k) = X(i,j)
     end if
   end do
 end do

end function

!_______________________________________________________
!
! X = full(S) converts a sparse matrix to a dense matrix

function full(S) result(X)
 ! Expand a sparse matrix into a dense S%m-by-S%n matrix.
 ! Positions without a stored entry are zero. If a position is stored more
 ! than once, the last entry in storage order wins.
 implicit none
 type(SparseMatrix),intent(in) :: S
 real   :: X(S%m,S%n)
 integer :: entry

 X = 0

 do entry=1,S%nz
   X(S%i(entry),S%j(entry)) = S%s(entry)
 end do
end function full





!_______________________________________________________
!
! operator:  .x.
!_______________________________________________________
!

function ssparsemat_mult_ssparsemat(A,B) result(C)
 ! Product C = A B of two sparse matrices in triplet form.
 !
 ! Every pair of stored entries A(i,p) and B(p,j) with the same inner index
 ! p contributes one triplet (i, j, A(i,p)*B(p,j)) to C. Contributions to
 ! the same position (i,j) are stored as separate triplets; they are not
 ! summed here.
 !
 !   A, B : operands; A%n must equal B%m, otherwise the routine stops
 !   C    : result; C%nz is the number of valid triplets. The component
 !          arrays can be longer than C%nz.
 !
 ! If the row indices B%i are in non-decreasing order, the matching entries
 ! of B are located by bisection; otherwise every pair of entries is tested.
 !
 ! The triplet buffers start with A%nz+B%nz elements and grow on demand, so
 ! the size of the product is not limited by a fixed buffer length.
 implicit none
 type(SparseMatrix), intent(in) :: A
 type(SparseMatrix), intent(in) :: B
 type(SparseMatrix) :: C
 integer :: k,l,nz,l1,l2,llower,lupper

 if (A%n.ne.B%m) then
   write(stderr,*) 'ssparsemat_mult_ssparsemat: size not conform: A.x.B '
   write(stderr,*) 'shape(A) ',A%m,A%n
   write(stderr,*) 'shape(B) ',B%m,B%n
   stop
 end if

 C%m = A%m
 C%n = B%n

 ! initial guess for the number of triplets; reserve() enlarges it
 nz = max(A%nz+B%nz,1)
 allocate(C%i(nz),C%j(nz),C%s(nz))
 nz=0
 C%nz = 0

 ! an empty B gives an empty product (and B%i(1) must not be referenced)
 if (B%nz.lt.1) return

 if (all(B%i(2:B%nz).ge.B%i(:B%nz-1))) then
   ! optimized version
  write(stdout,*) 'optimised version.'

   search: do k=1,A%nz
     l1 = 1
     l2 = B%nz
     ! quick cycle if possible
     if (A%j(k).lt.B%i(l1).or.A%j(k).gt.B%i(l2)) cycle search

     ! bisection for llower, the first entry of B in row A%j(k)
     l1 = 1
     l2 = B%nz
     llower = (l1+l2)/2

     do while ( &
     ! stop criteria
       .not.(B%i(llower).eq.A%j(k).and.(llower.eq.1.or.B%i(llower-1).ne.A%j(k))))

       if (B%i(llower).ge.A%j(k)) then
         l2 = llower-1
       else
         l1 = llower+1
       end if
       llower = (l1+l2)/2

       ! row A%j(k) of B has no stored entry
       if (l2.lt.l1) then
          cycle search
       end if
     end do

     ! bisection for lupper, the last entry of B in row A%j(k); it exists
     ! because llower was found
     l1 = 1
     l2 = B%nz
     lupper = (l1+l2)/2

     do while ( &
     ! stop criteria
       .not.(B%i(lupper).eq.A%j(k).and.(lupper.eq.B%nz.or.B%i(lupper+1).ne.A%j(k))))

       if (B%i(lupper).le.A%j(k)) then
         l1 = lupper+1
       else
         l2 = lupper-1
       end if
       lupper = (l1+l2)/2
     end do

     ! room for all products of A(k) with B(llower:lupper)
     call reserve(nz+lupper-llower+1)

     do l=llower,lupper
       nz = nz+1
       C%i(nz) = A%i(k)
       C%j(nz) = B%j(l)
       C%s(nz) = A%s(k)*B%s(l)
     end do

         if (mod(k,100).eq.0) then
           write(stderr,*) 'nz ',k,A%nz
         end if

   end do search

 else
   ! general version: every pair of entries is compared, this can take long
   write(stdout,*) 'Warning: unoptimised version.'
   do k=1,A%nz
     do l=1,B%nz
       if (A%j(k).eq.B%i(l)) then
         call reserve(nz+1)
         nz = nz+1
         C%i(nz) = A%i(k)
         C%j(nz) = B%j(l)
         C%s(nz) = A%s(k)*B%s(l)
         if (mod(nz,100).eq.0) then
           write(stderr,*) 'nz ',nz
         end if
       end if
     end do
   end do
 end if
 C%nz = nz

contains

 ! Make sure that the triplet arrays of C hold at least "needed" elements.
 ! The first nz elements, i.e. the triplets stored so far, are preserved.
 subroutine reserve(needed)
  implicit none
  integer, intent(in) :: needed
  integer, pointer :: newi(:),newj(:)
  real, pointer    :: news(:)
  integer :: newsize

  if (needed.le.size(C%i)) return

  ! at least double the size so that repeated growth stays cheap
  newsize = max(needed,2*size(C%i))
  allocate(newi(newsize),newj(newsize),news(newsize))

  newi(1:nz) = C%i(1:nz)
  newj(1:nz) = C%j(1:nz)
  news(1:nz) = C%s(1:nz)

  deallocate(C%i,C%j,C%s)
  C%i => newi
  C%j => newj
  C%s => news
 end subroutine reserve

end function ssparsemat_mult_ssparsemat

!--------------------------------------------
 


  subroutine permute(indeces,x,y)
   ! Gather the elements of x in the order given by indeces:
   !   y(i) = x(indeces(i))
   !
   !   indeces : source position of every output element
   !   x       : input vector, same length as indeces
   !   y       : output vector, same length as indeces
   implicit none
   integer, intent(in) :: indeces(:)
   real, intent(in) :: x(size(indeces))
   real, intent(out) :: y(size(indeces))

   ! the result is assembled completely before anything is stored in y
   real :: gathered(size(indeces))

   gathered = x(indeces)
   y = gathered

  end subroutine permute

!--------------------------------------------

  subroutine ipermute(indeces,x,y)
   ! Scatter the elements of x to the positions given by indeces:
   !   y(indeces(i)) = x(i)
   ! For a permutation this is the inverse of permute.
   !
   !   indeces : destination position of every input element
   !   x       : input vector, same length as indeces
   !   y       : output vector, same length as indeces
   implicit none
   integer, intent(in) :: indeces(:)
   real, intent(in) :: x(size(indeces))
   real, intent(out) :: y(size(indeces))

   integer :: src
   ! the result is assembled completely before anything is stored in y
   real :: scattered(size(indeces))

   do src=1,size(indeces)
     scattered(indeces(src)) = x(src)
   end do

   y = scattered

  end subroutine ipermute



!--------------------------------------------
!
! single precision
!
!--------------------------------------------



! Single precision instance of the generic code.
!
! The macros below select the real kind and the precision specific names
! that are in effect while matoper_inc.F90 is included. The macro values
! are the procedure names listed in the generic interfaces at the top of
! this module.

#define REAL_TYPE real(kind=4)

! blas and lapack subroutines

#define gemv_TYPE sgemv
#define getrf_TYPE sgetrf
#define getri_TYPE sgetri
#define gesvd_TYPE sgesvd
#define spevx_TYPE sspevx
#define copy_TYPE scopy
#define lamch_TYPE slamch
#define dot_TYPE sdot
#define gemm_TYPE sgemm
#define syevx_TYPE ssyevx

#define diag_TYPE sdiag
#define trace_TYPE strace

! operators

#define mat_mult_mat_TYPE smat_mult_smat
#define mat_mult_vec_TYPE smat_mult_svec
#define vec_mult_mat_TYPE svec_mult_smat
#define vec_mult_vec_TYPE svec_mult_svec

#define ssparsemat_mult_vec_TYPE ssparsemat_mult_svec
#define ssparsemat_mult_mat_TYPE ssparsemat_mult_smat
#define mat_mult_ssparsemat_TYPE smat_mult_ssparsemat
#define vec_mult_ssparsemat_TYPE svec_mult_ssparsemat

#define mat_mult_matT_TYPE smat_mult_smatT
#define mat_mult_ssparsematT_TYPE smat_mult_ssparsematT
#define vec_mult_vecT_TYPE svec_mult_svecT
#define matT_mult_mat_TYPE smatT_mult_smat
#define matT_mult_vec_TYPE smatT_mult_svec

#define diag_mult_mat_TYPE sdiag_mult_smat
#define diag_mult_vec_TYPE sdiag_mult_svec
#define mat_mult_diag_TYPE smat_mult_sdiag

! subroutines

#define inv_TYPE sinv
#define det_TYPE sdet
#define svd_TYPE svd_sf90
#define ssvd_TYPE ssvd_singleprec
#define gesvd_f90_TYPE gesvd_sf90 
#define symeig_TYPE symeig_sf90

! the same file is included a second time further down, for double precision
#include "matoper_inc.F90"




! undefine all macros

! Remove every single precision definition so that the same macro names
! can be defined again for the double precision instance.

#undef REAL_TYPE

! blas and lapack subroutines

#undef gemv_TYPE
#undef getrf_TYPE
#undef getri_TYPE
#undef gesvd_TYPE
#undef spevx_TYPE
#undef copy_TYPE
! the next macro is undefined once more in the subroutines group below
#undef ssvd_TYPE
#undef lamch_TYPE
#undef dot_TYPE
#undef gemm_TYPE
#undef syevx_TYPE

#undef diag_TYPE
#undef trace_TYPE

! operators


#undef mat_mult_mat_TYPE
#undef mat_mult_vec_TYPE
#undef vec_mult_mat_TYPE
#undef vec_mult_vec_TYPE

#undef ssparsemat_mult_vec_TYPE
#undef ssparsemat_mult_mat_TYPE
#undef mat_mult_ssparsemat_TYPE
#undef vec_mult_ssparsemat_TYPE

#undef mat_mult_matT_TYPE
#undef mat_mult_ssparsematT_TYPE
#undef vec_mult_vecT_TYPE
#undef matT_mult_mat_TYPE
#undef matT_mult_vec_TYPE

#undef diag_mult_mat_TYPE
#undef diag_mult_vec_TYPE
#undef mat_mult_diag_TYPE

! subroutines

#undef inv_TYPE
#undef det_TYPE
#undef svd_TYPE
#undef ssvd_TYPE
#undef gesvd_f90_TYPE
#undef symeig_TYPE





!--------------------------------------------
!
! double precision
!
!--------------------------------------------


!_______________________________________________________
!
! create the identity matrix
!

 function deye(n) result(E)
  ! Double precision identity matrix; this is the function behind the
  ! generic name eye.
  !   n : order of the matrix
  !   E : n-by-n result, 1 on the main diagonal and 0 elsewhere
  implicit none
  integer, intent(in) :: n
  real(8) :: E(n,n)
  integer :: col

  ! fill column by column: clear the column, then set its diagonal element
  do col=1,n
    E(:,col) = 0.
    E(col,col) = 1.
  end do

 end function 

! define macros for inclusion of matoper_inc.F90

! Double precision instance of the generic code.
!
! Same macro names as for the single precision instance, now with the
! double precision real kind, BLAS/LAPACK routines and procedure names.
! The macros stay defined up to the end of the module.

#define REAL_TYPE real(kind=8)

! blas and lapack subroutines

#define gemv_TYPE DGEMV
#define getrf_TYPE dgetrf
#define getri_TYPE dgetri
#define gesvd_TYPE dgesvd
#define spevx_TYPE dspevx
#define copy_TYPE dcopy
#define lamch_TYPE dlamch
#define dot_TYPE ddot
#define gemm_TYPE dGEMM
#define syevx_TYPE dsyevx

#define diag_TYPE ddiag
#define trace_TYPE dtrace

! operators

#define mat_mult_mat_TYPE dmat_mult_dmat
#define mat_mult_vec_TYPE dmat_mult_dvec
#define vec_mult_mat_TYPE dvec_mult_dmat
#define vec_mult_vec_TYPE dvec_mult_dvec

#define ssparsemat_mult_vec_TYPE ssparsemat_mult_dvec
#define ssparsemat_mult_mat_TYPE ssparsemat_mult_dmat
#define mat_mult_ssparsemat_TYPE dmat_mult_ssparsemat
#define vec_mult_ssparsemat_TYPE dvec_mult_ssparsemat

#define mat_mult_matT_TYPE dmat_mult_dmatT
#define mat_mult_ssparsematT_TYPE dmat_mult_ssparsematT
#define vec_mult_vecT_TYPE dvec_mult_dvecT
#define matT_mult_mat_TYPE dmatT_mult_dmat
#define matT_mult_vec_TYPE dmatT_mult_dvec

#define diag_mult_mat_TYPE ddiag_mult_dmat
#define diag_mult_vec_TYPE ddiag_mult_dvec
#define mat_mult_diag_TYPE dmat_mult_ddiag

! subroutines

#define inv_TYPE dinv
#define det_TYPE ddet
#define svd_TYPE svd_df90
#define ssvd_TYPE ssvd_doubleprec
#define gesvd_f90_TYPE gesvd_df90 
#define symeig_TYPE symeig_df90


! second inclusion of the generic code, this time for double precision
#include "matoper_inc.F90"


!!$
!!$!_______________________________________________________
!!$!
!!$
!!$subroutine gesvd_sf90(JOBU,JOBVT,A,S,U,VT,INFO)
!!$implicit none
!!$character, intent(in) :: jobu,jobvt
!!$real(kind=4), intent(in) :: A(:,:)
!!$real(kind=4), intent(out) :: S(:), U(:,:), VT(:,:)
!!$integer, intent(out) :: info
!!$
!!$real(kind=4) :: dummy,rlwork
!!$integer :: lwork
!!$real(kind=4), allocatable :: work(:)
!!$
!!$#ifndef ALLOCATE_LOCAL_VARS
!!$  real(kind=4) :: copyA(size(A,1),size(A,2))
!!$#else
!!$  real(kind=4), pointer :: copyA(:,:)
!!$  allocate(copyA(size(A,1),size(A,2)))
!!$#endif
!!$
!!$
!!$copyA = A
!!$
!!$call SGESVD( JOBU, JOBVT,size(A,1),size(A,2), copyA, size(A,1), &
!!$   S, U,size(U,1), VT,size(VT,1), &
!!$     rlWORK, -1, INFO)
!!$lwork = rlwork+0.5
!!$allocate(work(lwork))
!!$
!!$call SGESVD( JOBU, JOBVT,size(A,1),size(A,2), copyA, size(A,1), &
!!$   S, U,size(U,1), VT,size(VT,1), &
!!$     WORK, lwork, INFO)
!!$deallocate(work)
!!$#ifdef ALLOCATE_LOCAL_VARS
!!$  deallocate(copyA)
!!$#endif
!!$
!!$end subroutine 
!!$
!!$!_______________________________________________________
!!$!
!!$
!!$subroutine gesvd_df90(JOBU,JOBVT,A,S,U,VT,INFO)
!!$implicit none
!!$character, intent(in) :: jobu,jobvt
!!$real(kind=8), intent(in) :: A(:,:)
!!$real(kind=8), intent(out) :: S(:), U(:,:), VT(:,:)
!!$integer, intent(out) :: info
!!$
!!$real(kind=8) :: dummy,rlwork
!!$integer :: lwork
!!$real(kind=8), allocatable :: work(:)
!!$
!!$#ifndef ALLOCATE_LOCAL_VARS
!!$  real(kind=4) :: copyA(size(A,1),size(A,2))
!!$#else
!!$  real(kind=4), pointer :: copyA(:,:)
!!$  allocate(copyA(size(A,1),size(A,2)))
!!$#endif
!!$
!!$copyA = A
!!$
!!$call DGESVD( JOBU, JOBVT,size(A,1),size(A,2), copyA, size(A,1), &
!!$   S, U,size(U,1), VT,size(VT,1), &
!!$     rlWORK, -1, INFO)
!!$
!!$lwork = rlwork+0.5
!!$allocate(work(lwork))
!!$
!!$call DGESVD( JOBU, JOBVT,size(A,1),size(A,2), copyA, size(A,1), &
!!$   S, U,size(U,1), VT,size(VT,1), &
!!$     WORK, lwork, INFO)
!!$deallocate(work)
!!$
!!$#ifdef ALLOCATE_LOCAL_VARS
!!$  deallocate(copyA)
!!$#endif
!!$
!!$
!!$end subroutine 
!!$
!!$
!!$
!!$!_______________________________________________________
!!$!
!!$! computes eigenvalue/-vector of a symetric matrix
!!$!
!!$! A = V' diag(E) V 
!!$
!!$subroutine symeig_sf90(A,E,V,nbiggest,nsmallest,indices,INFO)
!!$implicit none
!!$real(kind=4), intent(in)  :: A(:,:)
!!$real(4),         intent(out) :: E(size(A,1))
!!$integer, optional, intent(in) :: nbiggest, nsmallest,indices(2)
!!$real(4), optional, target, intent(out) :: V(:,:)
!!$integer, optional, intent(out) :: info
!!$
!!$character :: jobz
!!$real(kind=4), pointer :: pV(:,:)
!!$real(kind=4) :: rlwork, tmp
!!$integer :: lwork, myinfo, N, iwork(5*size(A,1)), ifail(size(A,1)), i,j
!!$real(kind=4), allocatable :: work(:)
!!$
!!$! LAPACK Machine precision routine
!!$
!!$real :: slamch
!!$integer :: r,idummy,ind(2)
!!$
!!$
!!$#ifndef ALLOCATE_LOCAL_VARS
!!$  real(kind=4) :: copyA(size(A,1),size(A,2))
!!$#else
!!$  real(kind=4), pointer :: copyA(:,:)
!!$  allocate(copyA(size(A,1),size(A,2)))
!!$#endif
!!$
!!$
!!$
!!$jobz='n'
!!$N = size(A,1)
!!$r = n
!!$
!!$if (present(V)) then
!!$  jobz='v'
!!$  pV => V
!!$else
!!$  allocate(pV(1,1))
!!$end if
!!$
!!$ind = (/ 1,n /)
!!$
!!$if (present(nbiggest))  ind = (/ n-nbiggest+1,n /)
!!$if (present(nsmallest)) ind = (/ 1,nsmallest /)
!!$if (present(indices))   ind = indices
!!$
!!$! protect content of A
!!$
!!$copyA = A
!!$
!!$! determine the optimal size of work
!!$
!!$call SSYEVX(JOBZ,'I','U',n,copyA,n,-1.,-1.,ind(1),ind(2),     &
!!$     2*SLAMCH('S'),idummy,E,pV,n,rlWORK,-1, IWORK,   &
!!$     IFAIL, myINFO )
!!$
!!$lwork = rlwork+0.5
!!$allocate(work(lwork))
!!$
!!$call SSYEVX(JOBZ,'I','U',n,copyA,n,-1.,-1.,ind(1),ind(2),     &
!!$     2*SLAMCH('S'),idummy,E,pV,n, WORK, LWORK, IWORK,   &
!!$     IFAIL, myINFO )
!!$
!!$  if (present(nbiggest)) then
!!$! sort in descending order
!!$    do i=1,nbiggest/2
!!$      tmp = E(i)
!!$      E(i) = E(nbiggest-i+1)
!!$      E(nbiggest-i+1) = tmp
!!$
!!$      if (present(V)) then
!!$       do j=1,n
!!$        tmp = V(j,i) 
!!$        V(j,i) = V(j,nbiggest-i+1)
!!$        V(j,nbiggest-i+1) = tmp
!!$       end do
!!$      end if
!!$    end do
!!$  end if
!!$
!!$deallocate(work)
!!$#ifdef ALLOCATE_LOCAL_VARS
!!$  deallocate(copyA)
!!$#endif
!!$
!!$
!!$if (.not.present(V)) deallocate(pV)
!!$if (present(info)) info = myinfo
!!$end subroutine
!!$
!!$!_______________________________________________________
!!$!


end module
